""" Dataloader for prostate MRIs """
import os

import h5py
import numpy as np
from PIL import Image as im
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader


class HDF5Dataset(Dataset):
    """
    Dataset of paired ``dwi1`` / ``t2ax_cropped`` volumes read from per-sample HDF5 files.

    Every sample's arrays are loaded fully into memory at construction time;
    ``__getitem__`` normalizes ``dwi1``, reorders axes, optionally flips both
    volumes (training only) and returns ``(dwi1, t2ax, label)`` as float32.

    Input params:
        path: Path to the folder containing the dataset (one ``<sample_id>.hdf5``
            file per sample).
        samples_id: Iterable of sample ids to fetch; its order defines dataset indices.
        labels: Array-like of labels, aligned one-to-one with ``samples_id``.
        set_type: Split name; random flipping is applied only when it equals ``"train"``.
        transform: Optional callable applied to each squeezed float32 image
            (default = None).

    Raises:
        ValueError: if the number of loaded samples differs from the number of labels.
    """

    def __init__(self, path, samples_id, labels, set_type, transform=None):
        super().__init__()
        # Input:
        self.transform = transform
        self.images_path = path
        self.samples_id = samples_id
        self.set_type = set_type
        # Return:
        self.dwi1_images = []
        self.t2ax_images = []
        # np.asarray lets plain lists/tuples be passed as well as numpy arrays.
        self.labels = np.asarray(labels).tolist()

        for sample in self.samples_id:
            file_path = os.path.join(self.images_path, f"{sample}.hdf5")
            # The context manager closes the file once the data has been
            # materialized in memory with [:]; the original never closed it.
            with h5py.File(file_path, "r") as img:
                self.dwi1_images.append(img["dwi1"][:])
                self.t2ax_images.append(img["t2ax_cropped"][:])

        # __getitem__ uses one index for images and labels, so they must align.
        if len(self.dwi1_images) != len(self.labels):
            raise ValueError(
                f"Loaded {len(self.dwi1_images)} samples but got "
                f"{len(self.labels)} labels; they must match."
            )

    def __getitem__(self, index):
        """Return ``(dwi1, t2ax, label)`` for ``index``.

        All three are float32 numpy arrays, unless ``self.transform`` is set,
        in which case the two images are whatever the transform returns.
        """
        dwi1 = self.dwi1_images[index]
        t2ax = self.t2ax_images[index]
        label = self.labels[index]

        # Percentile normalization, then move the last axis to position 1:
        # (a, b, c, d) -> (a, d, b, c). dwi1 must be 4-D for this transpose.
        dwi1 = self.min_max_normalize(dwi1)
        dwi1 = dwi1.transpose(0, 3, 1, 2)
        # (a, b, c) -> (c, a, b), then prepend a singleton axis -> (1, c, a, b).
        t2ax = np.expand_dims(t2ax.transpose(2, 0, 1), axis=0)
        if self.set_type == "train":
            # Augmentation: flip both volumes together with probability 0.5.
            dwi1, t2ax = self.flip_transform(dwi1, t2ax, 0.5)

        # astype copies, so later in-place transforms cannot touch the cache.
        dwi1 = dwi1.astype(np.float32)
        t2ax = t2ax.astype(np.float32)
        label = np.asarray(label).astype(np.float32)

        if self.transform is not None:
            dwi1 = self.transform(np.squeeze(dwi1))
            t2ax = self.transform(np.squeeze(t2ax))

        return dwi1, t2ax, label

    def __len__(self):
        """Number of samples (one per label)."""
        return len(self.labels)

    def min_max_normalize(self, I, min_p=20, max_p=80):
        """
        Percentile-based intensity scaling.

        For a 3-D array, maps the ``min_p``-th percentile to -2.5 and the
        ``max_p``-th percentile to +2.5 (values outside are not clipped).
        For a 4-D array, each ``I[j]`` sub-volume is scaled independently.

        Returns a floating-point array; integer inputs are promoted so the
        scaled values are not truncated.

        Raises:
            ValueError: if ``I`` is neither 3-D nor 4-D.
        """
        if I.ndim == 3:
            mI = np.percentile(I, min_p)
            MI = np.percentile(I, max_p)
            # NOTE(review): if MI == mI (e.g. a constant volume) this divides
            # by zero and yields inf/nan values.
            return ((I - mI) / (MI - mI) - 0.5) * 5.0
        if I.ndim == 4:
            # Float buffer: np.zeros_like(I) on an integer volume would
            # truncate the normalized values to integers.
            K = np.zeros(I.shape, dtype=np.result_type(I.dtype, np.float32))
            for j, volume in enumerate(I):
                K[j] = self.min_max_normalize(volume, min_p, max_p)
            return K
        raise ValueError(f"Expected a 3-D or 4-D array, got ndim={I.ndim}")

    def flip_transform(self, dwi, t2ax, prob_flip=0.2):
        """
        With probability ``prob_flip``, reverse both arrays along axis 3.

        A single random draw decides for both arrays, so they are either both
        flipped or both returned unchanged (the same objects that were passed in).
        """
        if np.random.rand() < prob_flip:
            # .copy() turns np.flip's view into a standalone array.
            return np.flip(dwi, axis=3).copy(), np.flip(t2ax, axis=3).copy()
        return dwi, t2ax
